import itertools
import unittest

import numpy as np

import causallearn.utils.cit as cit


# TODO: Design more comprehensive test cases, including datasets that cover corner cases.
class TestCIT_KCI(unittest.TestCase):
    """Regression tests for the KCI conditional-independence test (``cit.CIT(data, 'kci', ...)``).

    Each test builds a 300 x 4 dataset with columns (X, X_prime, Y, Z), where
    X_prime is drawn separately from X, Y = X + 0.5 * noise and
    Z = Y + 0.5 * noise.  Hence X and X_prime are independent, X and Z are
    dependent, and X and Z are independent conditional on Y.  KCI is evaluated
    over the full hyper-parameter grid below and the p-values (rounded to 4
    decimals) are compared against recorded values.
    """

    # Hyper-parameter grid swept by every test.  The recorded p-value lists are
    # laid out in exactly this nesting order: kernel outermost, polyd innermost.
    KERNELS = ['Gaussian', 'Polynomial', 'Linear']
    EST_WIDTHS = ['empirical', 'median', 'manual']
    KWIDTHS = [0.5, 1.0, 2.0]
    USE_GPS = [True, False]
    APPROXS = [True, False]
    POLYDS = [1, 2]
    N_CONFIGS = (len(KERNELS) * len(EST_WIDTHS) * len(KWIDTHS)
                 * len(USE_GPS) * len(APPROXS) * len(POLYDS))

    # Significance level used for the qualitative (in)dependence checks.
    ALPHA = 0.01

    @classmethod
    def _kci_pvalue_grid(cls, data):
        """Run KCI on ``data`` for every configuration of the hyper-parameter grid.

        :param data: array whose columns are (X, X_prime, Y, Z).
        :return: three lists, one entry per configuration in grid order, of
            p-values rounded to 4 decimals for X vs X_prime (columns 0, 1),
            X vs Z (columns 0, 3) and X vs Z given Y (columns 0, 3 | {2}).
        """
        pvalue01, pvalue03, pvalue032 = [], [], []
        # itertools.product varies its last argument fastest, which reproduces
        # the original nested-loop order (and therefore the RNG consumption
        # order of the permutation-based, non-approximate tests).
        grid = itertools.product(cls.KERNELS, cls.EST_WIDTHS, cls.KWIDTHS,
                                 cls.USE_GPS, cls.APPROXS, cls.POLYDS)
        for kernelname, est_width, kwidth, use_gp, approx, polyd in grid:
            cit_CIT = cit.CIT(data, 'kci', kernelX=kernelname, kernelY=kernelname,
                              kernelZ=kernelname, est_width=est_width, use_gp=use_gp, approx=approx,
                              polyd=polyd, kwidthx=kwidth, kwidthy=kwidth, kwidthz=kwidth)
            # Keep the call order (0,1), (0,3), (0,3|{2}) fixed: it determines
            # which random draws each test consumes.
            pvalue01.append(round(cit_CIT(0, 1), 4))
            pvalue03.append(round(cit_CIT(0, 3), 4))
            pvalue032.append(round(cit_CIT(0, 3, {2}), 4))
        return pvalue01, pvalue03, pvalue032

    def _assert_pvalues_match(self, data, pvalue01_truth, pvalue03_truth, pvalue032_truth):
        """Run the KCI grid on ``data`` and check the p-values.

        The qualitative expectations at level ``ALPHA`` are checked first, so a
        semantic failure is reported separately from small numerical drift:

        * X and X_prime are independent: every p-value for (0, 1) > ALPHA;
        * X and Z are dependent: every p-value for (0, 3) < ALPHA;
        * X and Z are independent given Y: every p-value for (0, 3 | {2}) > ALPHA.

        Then the exact rounded p-values are compared against the recorded ones.
        """
        pvalue01, pvalue03, pvalue032 = self._kci_pvalue_grid(data)

        self.assertTrue(all(p > self.ALPHA for p in pvalue01),
                        msg=f'X vs X_prime wrongly judged dependent: {pvalue01}')
        self.assertTrue(all(p < self.ALPHA for p in pvalue03),
                        msg=f'X vs Z wrongly judged independent: {pvalue03}')
        self.assertTrue(all(p > self.ALPHA for p in pvalue032),
                        msg=f'X vs Z | Y wrongly judged dependent: {pvalue032}')

        self.assertEqual(pvalue01, pvalue01_truth)
        self.assertEqual(pvalue03, pvalue03_truth)
        self.assertEqual(pvalue032, pvalue032_truth)

    def test_Gaussian_dist(self):
        np.random.seed(10)
        X = np.random.randn(300, 1)
        X_prime = np.random.randn(300, 1)
        Y = X + 0.5 * np.random.randn(300, 1)
        Z = Y + 0.5 * np.random.randn(300, 1)
        data = np.hstack((X, X_prime, Y, Z))

        pvalue01_truth = [0.5404, 0.5404, 0.507, 0.501, 0.5404, 0.5404, 0.516, 0.536, 0.5404, 0.5404, 0.506, 0.517,
                          0.5404, 0.5404, 0.526, 0.526, 0.5404, 0.5404, 0.492, 0.507, 0.5404, 0.5404, 0.529, 0.511,
                          0.6106, 0.6106, 0.633, 0.594, 0.6106, 0.6106, 0.612, 0.59, 0.6106, 0.6106, 0.595, 0.59,
                          0.6106, 0.6106, 0.606, 0.589, 0.6106, 0.6106, 0.616, 0.587, 0.6106, 0.6106, 0.595, 0.596,
                          0.5404, 0.5404, 0.522, 0.501, 0.5404, 0.5404, 0.524, 0.53, 0.5864, 0.5864, 0.574, 0.574,
                          0.5864, 0.5864, 0.575, 0.603, 0.4901, 0.4901, 0.487, 0.463, 0.4901, 0.4901, 0.47, 0.493,
                          0.2745, 0.1613, 0.251, 0.167, 0.2745, 0.1613, 0.274, 0.15, 0.2745, 0.1613, 0.272, 0.143,
                          0.2745, 0.1613, 0.276, 0.158, 0.2745, 0.1613, 0.272, 0.149, 0.2745, 0.1613, 0.268, 0.142,
                          0.2745, 0.1613, 0.279, 0.141, 0.2745, 0.1613, 0.279, 0.166, 0.2745, 0.1613, 0.26, 0.152,
                          0.2745, 0.1613, 0.27, 0.16, 0.2745, 0.1613, 0.262, 0.145, 0.2745, 0.1613, 0.291, 0.154,
                          0.2745, 0.1613, 0.254, 0.14, 0.2745, 0.1613, 0.253, 0.16, 0.2745, 0.1613, 0.272, 0.17,
                          0.2745, 0.1613, 0.268, 0.168, 0.2745, 0.1613, 0.285, 0.165, 0.2745, 0.1613, 0.276, 0.147,
                          0.2745, 0.2745, 0.274, 0.272, 0.2745, 0.2745, 0.279, 0.277, 0.2745, 0.2745, 0.299, 0.28,
                          0.2745, 0.2745, 0.263, 0.258, 0.2745, 0.2745, 0.258, 0.269, 0.2745, 0.2745, 0.289, 0.295,
                          0.2745, 0.2745, 0.294, 0.283, 0.2745, 0.2745, 0.286, 0.272, 0.2745, 0.2745, 0.267, 0.273,
                          0.2745, 0.2745, 0.27, 0.276, 0.2745, 0.2745, 0.257, 0.269, 0.2745, 0.2745, 0.274, 0.264,
                          0.2745, 0.2745, 0.249, 0.302, 0.2745, 0.2745, 0.282, 0.259, 0.2745, 0.2745, 0.262, 0.265,
                          0.2745, 0.2745, 0.244, 0.264, 0.2745, 0.2745, 0.295, 0.275, 0.2745, 0.2745, 0.261, 0.265]
        pvalue03_truth = [0.0] * self.N_CONFIGS
        pvalue032_truth = [0.6087, 0.6087, 0.5956, 0.6, 0.5807, 0.5807, 0.583, 0.5612, 0.6087, 0.6087, 0.5952, 0.5918,
                           0.5807, 0.5807, 0.567, 0.5744, 0.6087, 0.6087, 0.5944, 0.6074, 0.5807, 0.5807, 0.5878,
                           0.5558, 0.6164, 0.6164, 0.6252, 0.628, 0.6179, 0.6179, 0.6158, 0.6076, 0.6164, 0.6164,
                           0.617, 0.6208, 0.6179, 0.6179, 0.6152, 0.6154, 0.6164, 0.6164, 0.6108, 0.6196, 0.6179,
                           0.6179, 0.6384, 0.6198, 0.729, 0.729, 0.7334, 0.7246, 0.6899, 0.6899, 0.6918, 0.6874,
                           0.6079, 0.6079, 0.6016, 0.6068, 0.5938, 0.5938, 0.598, 0.5752, 0.571, 0.571, 0.5638,
                           0.5714, 0.5737, 0.5737, 0.5702, 0.5608, 0.9111, 0.247, 0.9098, 0.2272, 0.9111, 0.247,
                           0.9048, 0.2262, 0.9111, 0.247, 0.9106, 0.2488, 0.9111, 0.247, 0.9106, 0.2312, 0.9111,
                           0.247, 0.9122, 0.224, 0.9111, 0.247, 0.9224, 0.2218, 0.9111, 0.247, 0.9148, 0.222, 0.9111,
                           0.247, 0.9082, 0.216, 0.9111, 0.247, 0.9154, 0.2294, 0.9111, 0.247, 0.9024, 0.2218, 0.9111,
                           0.247, 0.9142, 0.2224, 0.9111, 0.247, 0.9178, 0.2292, 0.9111, 0.247, 0.9098, 0.23, 0.9111,
                           0.247, 0.9192, 0.224, 0.9111, 0.247, 0.9066, 0.2316, 0.9111, 0.247, 0.917, 0.2302, 0.9111,
                           0.247, 0.9134, 0.2392, 0.9111, 0.247, 0.912, 0.2376, 0.9111, 0.9111, 0.8996, 0.9074, 0.9111,
                           0.9111, 0.9124, 0.915, 0.9111, 0.9111, 0.9102, 0.9106, 0.9111, 0.9111, 0.912, 0.9082,
                           0.9111, 0.9111, 0.9134, 0.9104, 0.9111, 0.9111, 0.9196, 0.9114, 0.9111, 0.9111, 0.908,
                           0.912, 0.9111, 0.9111, 0.9114, 0.9116, 0.9111, 0.9111, 0.9074, 0.9066, 0.9111, 0.9111,
                           0.9062, 0.9116, 0.9111, 0.9111, 0.9156, 0.907, 0.9111, 0.9111, 0.9116, 0.9078, 0.9111,
                           0.9111, 0.9052, 0.916, 0.9111, 0.9111, 0.912, 0.9098, 0.9111, 0.9111, 0.9068, 0.9162,
                           0.9111, 0.9111, 0.9098, 0.9098, 0.9111, 0.9111, 0.9132, 0.9136, 0.9111, 0.9111, 0.9136,
                           0.9118]
        self._assert_pvalues_match(data, pvalue01_truth, pvalue03_truth, pvalue032_truth)

    def test_Exponential_dist(self):
        np.random.seed(10)
        X = np.random.exponential(size=(300, 1))
        X_prime = np.random.exponential(size=(300, 1))
        Y = X + 0.5 * np.random.exponential(size=(300, 1))
        Z = Y + 0.5 * np.random.exponential(size=(300, 1))
        data = np.hstack((X, X_prime, Y, Z))

        pvalue01_truth = [0.8513, 0.8513, 0.872, 0.873, 0.8513, 0.8513, 0.871, 0.897, 0.8513, 0.8513, 0.879, 0.886,
                          0.8513, 0.8513, 0.889, 0.891, 0.8513, 0.8513, 0.872, 0.876, 0.8513, 0.8513, 0.853, 0.866,
                          0.5809, 0.5809, 0.573, 0.568, 0.5809, 0.5809, 0.548, 0.571, 0.5809, 0.5809, 0.593, 0.588,
                          0.5809, 0.5809, 0.577, 0.577, 0.5809, 0.5809, 0.581, 0.57, 0.5809, 0.5809, 0.596, 0.598,
                          0.8513, 0.8513, 0.866, 0.877, 0.8513, 0.8513, 0.876, 0.874, 0.5604, 0.5604, 0.565, 0.562,
                          0.5604, 0.5604, 0.522, 0.526, 0.5048, 0.5048, 0.48, 0.49, 0.5048, 0.5048, 0.488, 0.49,
                          0.8219, 0.5496, 0.807, 0.553, 0.8219, 0.5496, 0.825, 0.562, 0.8219, 0.5496, 0.801, 0.542,
                          0.8219, 0.5496, 0.823, 0.548, 0.8219, 0.5496, 0.824, 0.549, 0.8219, 0.5496, 0.83, 0.548,
                          0.8219, 0.5496, 0.795, 0.557, 0.8219, 0.5496, 0.818, 0.547, 0.8219, 0.5496, 0.821, 0.57,
                          0.8219, 0.5496, 0.823, 0.539, 0.8219, 0.5496, 0.843, 0.564, 0.8219, 0.5496, 0.823, 0.531,
                          0.8219, 0.5496, 0.802, 0.538, 0.8219, 0.5496, 0.811, 0.544, 0.8219, 0.5496, 0.796, 0.572,
                          0.8219, 0.5496, 0.822, 0.586, 0.8219, 0.5496, 0.818, 0.565, 0.8219, 0.5496, 0.822, 0.569,
                          0.8219, 0.8219, 0.814, 0.811, 0.8219, 0.8219, 0.861, 0.78, 0.8219, 0.8219, 0.85, 0.855,
                          0.8219, 0.8219, 0.815, 0.818, 0.8219, 0.8219, 0.829, 0.818, 0.8219, 0.8219, 0.825, 0.818,
                          0.8219, 0.8219, 0.839, 0.821, 0.8219, 0.8219, 0.83, 0.812, 0.8219, 0.8219, 0.828, 0.83,
                          0.8219, 0.8219, 0.824, 0.806, 0.8219, 0.8219, 0.833, 0.844, 0.8219, 0.8219, 0.824, 0.825,
                          0.8219, 0.8219, 0.827, 0.817, 0.8219, 0.8219, 0.827, 0.826, 0.8219, 0.8219, 0.817, 0.835,
                          0.8219, 0.8219, 0.829, 0.821, 0.8219, 0.8219, 0.832, 0.814, 0.8219, 0.8219, 0.835, 0.8]
        pvalue03_truth = [0.0] * self.N_CONFIGS
        pvalue032_truth = [0.4088, 0.4088, 0.3792, 0.3764, 0.4076, 0.4076, 0.3732, 0.3746, 0.4088, 0.4088, 0.3834,
                           0.374, 0.4076, 0.4076, 0.3822, 0.375, 0.4088, 0.4088, 0.3702, 0.3806, 0.4076, 0.4076,
                           0.3674, 0.3638, 0.627, 0.627, 0.6232, 0.6236, 0.6756, 0.6756, 0.6788, 0.6806, 0.627, 0.627,
                           0.622, 0.6254, 0.6756, 0.6756, 0.6872, 0.6812, 0.627, 0.627, 0.6196, 0.6076, 0.6756, 0.6756,
                           0.6858, 0.6656, 0.4087, 0.4087, 0.3898, 0.3886, 0.3398, 0.3398, 0.3092, 0.3042, 0.5165,
                           0.5165, 0.4958, 0.4912, 0.5326, 0.5326, 0.5226, 0.5288, 0.8561, 0.8561, 0.8962, 0.8864,
                           0.8749, 0.8749, 0.915, 0.9118, 0.7353, 0.515, 0.735, 0.511, 0.7353, 0.515, 0.7274, 0.507,
                           0.7353, 0.515, 0.737, 0.509, 0.7353, 0.515, 0.731, 0.5084, 0.7353, 0.515, 0.7338, 0.4996,
                           0.7353, 0.515, 0.7312, 0.5156, 0.7353, 0.515, 0.7414, 0.5224, 0.7353, 0.515, 0.7312, 0.519,
                           0.7353, 0.515, 0.737, 0.5046, 0.7353, 0.515, 0.7328, 0.5204, 0.7353, 0.515, 0.738, 0.5058,
                           0.7353, 0.515, 0.728, 0.5016, 0.7353, 0.515, 0.7416, 0.514, 0.7353, 0.515, 0.724, 0.5174,
                           0.7353, 0.515, 0.7342, 0.5118, 0.7353, 0.515, 0.7338, 0.5156, 0.7353, 0.515, 0.7388, 0.5016,
                           0.7353, 0.515, 0.737, 0.5102, 0.7353, 0.7353, 0.7348, 0.738, 0.7353, 0.7353, 0.7392, 0.732,
                           0.7353, 0.7353, 0.7328, 0.7268, 0.7353, 0.7353, 0.737, 0.7398, 0.7353, 0.7353, 0.736, 0.7416,
                           0.7353, 0.7353, 0.7398, 0.7374, 0.7353, 0.7353, 0.7314, 0.737, 0.7353, 0.7353, 0.7338, 0.7354,
                           0.7353, 0.7353, 0.7328, 0.7372, 0.7353, 0.7353, 0.7352, 0.7356, 0.7353, 0.7353, 0.739, 0.7336,
                           0.7353, 0.7353, 0.7404, 0.7386, 0.7353, 0.7353, 0.7398, 0.7432, 0.7353, 0.7353, 0.7386,
                           0.7402, 0.7353, 0.7353, 0.728, 0.7288, 0.7353, 0.7353, 0.7328, 0.7324, 0.7353, 0.7353,
                           0.7412, 0.7398, 0.7353, 0.7353, 0.7412, 0.7272]
        self._assert_pvalues_match(data, pvalue01_truth, pvalue03_truth, pvalue032_truth)

    def test_Uniform_dist(self):
        np.random.seed(10)
        X = np.random.uniform(size=(300, 1))
        X_prime = np.random.uniform(size=(300, 1))
        Y = X + 0.5 * np.random.uniform(size=(300, 1))
        Z = Y + 0.5 * np.random.uniform(size=(300, 1))
        data = np.hstack((X, X_prime, Y, Z))

        pvalue01_truth = [0.8099, 0.8099, 0.815, 0.827, 0.8099, 0.8099, 0.821, 0.82, 0.8099, 0.8099, 0.852, 0.828,
                          0.8099, 0.8099, 0.83, 0.825, 0.8099, 0.8099, 0.831, 0.83, 0.8099, 0.8099, 0.813, 0.817,
                          0.7897, 0.7897, 0.809, 0.814, 0.7897, 0.7897, 0.788, 0.789, 0.7897, 0.7897, 0.798, 0.805,
                          0.7897, 0.7897, 0.803, 0.786, 0.7897, 0.7897, 0.804, 0.802, 0.7897, 0.7897, 0.793, 0.794,
                          0.8099, 0.8099, 0.815, 0.815, 0.8099, 0.8099, 0.816, 0.826, 0.5796, 0.5796, 0.551, 0.572,
                          0.5796, 0.5796, 0.565, 0.593, 0.5546, 0.5546, 0.568, 0.546, 0.5546, 0.5546, 0.546, 0.564,
                          0.7155, 0.4159, 0.7, 0.401, 0.7155, 0.4159, 0.719, 0.402, 0.7155, 0.4159, 0.721, 0.419,
                          0.7155, 0.4159, 0.708, 0.412, 0.7155, 0.4159, 0.708, 0.416, 0.7155, 0.4159, 0.704, 0.394,
                          0.7155, 0.4159, 0.71, 0.426, 0.7155, 0.4159, 0.707, 0.419, 0.7155, 0.4159, 0.728, 0.406,
                          0.7155, 0.4159, 0.712, 0.384, 0.7155, 0.4159, 0.718, 0.392, 0.7155, 0.4159, 0.715, 0.365,
                          0.7155, 0.4159, 0.723, 0.387, 0.7155, 0.4159, 0.72, 0.398, 0.7155, 0.4159, 0.705, 0.406,
                          0.7155, 0.4159, 0.704, 0.379, 0.7155, 0.4159, 0.709, 0.392, 0.7155, 0.4159, 0.717, 0.406,
                          0.7155, 0.7155, 0.712, 0.705, 0.7155, 0.7155, 0.719, 0.722, 0.7155, 0.7155, 0.684, 0.715,
                          0.7155, 0.7155, 0.702, 0.705, 0.7155, 0.7155, 0.732, 0.729, 0.7155, 0.7155, 0.688, 0.701,
                          0.7155, 0.7155, 0.729, 0.737, 0.7155, 0.7155, 0.701, 0.722, 0.7155, 0.7155, 0.717, 0.714,
                          0.7155, 0.7155, 0.732, 0.711, 0.7155, 0.7155, 0.709, 0.708, 0.7155, 0.7155, 0.709, 0.708,
                          0.7155, 0.7155, 0.702, 0.723, 0.7155, 0.7155, 0.722, 0.708, 0.7155, 0.7155, 0.72, 0.695,
                          0.7155, 0.7155, 0.7, 0.71, 0.7155, 0.7155, 0.705, 0.73, 0.7155, 0.7155, 0.736, 0.699]
        pvalue03_truth = [0.0] * self.N_CONFIGS
        pvalue032_truth = [0.6393, 0.6393, 0.6354, 0.6396, 0.6124, 0.6124, 0.5972, 0.618, 0.6393, 0.6393, 0.639,
                           0.6288, 0.6124, 0.6124, 0.607, 0.6064, 0.6393, 0.6393, 0.6438, 0.6442, 0.6124, 0.6124,
                           0.613, 0.595, 0.8899, 0.8899, 0.9352, 0.934, 0.8887, 0.8887, 0.9308, 0.9344, 0.8899, 0.8899,
                           0.935, 0.9354, 0.8887, 0.8887, 0.9336, 0.939, 0.8899, 0.8899, 0.9382, 0.9378, 0.8887, 0.8887,
                           0.9312, 0.9352, 0.5581, 0.5581, 0.5508, 0.5398, 0.5246, 0.5246, 0.5184, 0.514, 0.833, 0.833,
                           0.8634, 0.859, 0.8296, 0.8296, 0.859, 0.8538, 0.8861, 0.8861, 0.9372, 0.9456, 0.8705, 0.8705,
                           0.9246, 0.9294, 0.3721, 0.6544, 0.3794, 0.656, 0.3721, 0.6544, 0.3832, 0.6548, 0.3721, 0.6544,
                           0.3696, 0.6606, 0.3721, 0.6544, 0.373, 0.6516, 0.3721, 0.6544, 0.3814, 0.653, 0.3721, 0.6544,
                           0.3706, 0.6638, 0.3721, 0.6544, 0.373, 0.6624, 0.3721, 0.6544, 0.3728, 0.6566, 0.3721, 0.6544,
                           0.373, 0.6752, 0.3721, 0.6544, 0.3686, 0.6644, 0.3721, 0.6544, 0.3744, 0.6684, 0.3721, 0.6544,
                           0.3758, 0.663, 0.3721, 0.6544, 0.3752, 0.6462, 0.3721, 0.6544, 0.3694, 0.6738, 0.3721, 0.6544,
                           0.3684, 0.6602, 0.3721, 0.6544, 0.3682, 0.6734, 0.3721, 0.6544, 0.3732, 0.662, 0.3721, 0.6544,
                           0.3718, 0.6638, 0.3721, 0.3721, 0.3688, 0.372, 0.3721, 0.3721, 0.3596, 0.3722, 0.3721, 0.3721,
                           0.3662, 0.361, 0.3721, 0.3721, 0.379, 0.3716, 0.3721, 0.3721, 0.379, 0.3684, 0.3721, 0.3721,
                           0.3754, 0.3674, 0.3721, 0.3721, 0.3636, 0.38, 0.3721, 0.3721, 0.3756, 0.3662, 0.3721, 0.3721,
                           0.373, 0.3702, 0.3721, 0.3721, 0.3704, 0.3746, 0.3721, 0.3721, 0.3574, 0.3616, 0.3721, 0.3721,
                           0.359, 0.3702, 0.3721, 0.3721, 0.366, 0.3704, 0.3721, 0.3721, 0.3682, 0.3732, 0.3721, 0.3721,
                           0.3836, 0.3806, 0.3721, 0.3721, 0.371, 0.3596, 0.3721, 0.3721, 0.3728, 0.3744, 0.3721, 0.3721,
                           0.3708, 0.388]
        self._assert_pvalues_match(data, pvalue01_truth, pvalue03_truth, pvalue032_truth)

    def test_Mixed_dist(self):
        np.random.seed(10)
        X = np.random.uniform(size=(300, 1))
        X_prime = np.random.randn(300, 1)
        Y = X + 0.5 * np.random.exponential(size=(300, 1))
        Z = Y + 0.5 * np.random.randn(300, 1)
        data = np.hstack((X, X_prime, Y, Z))

        pvalue01_truth = [0.6565, 0.6565, 0.637, 0.668, 0.6565, 0.6565, 0.64, 0.659, 0.6565, 0.6565, 0.646, 0.632,
                          0.6565, 0.6565, 0.67, 0.646, 0.6565, 0.6565, 0.655, 0.668, 0.6565, 0.6565, 0.661, 0.663,
                          0.5346, 0.5346, 0.524, 0.507, 0.5346, 0.5346, 0.517, 0.511, 0.5346, 0.5346, 0.535, 0.514,
                          0.5346, 0.5346, 0.505, 0.526, 0.5346, 0.5346, 0.534, 0.518, 0.5346, 0.5346, 0.517, 0.507,
                          0.6565, 0.6565, 0.633, 0.642, 0.6565, 0.6565, 0.659, 0.64, 0.6557, 0.6557, 0.668, 0.68,
                          0.6557, 0.6557, 0.654, 0.66, 0.6663, 0.6663, 0.701, 0.693, 0.6663, 0.6663, 0.704, 0.698,
                          0.7537, 0.5882, 0.74, 0.618, 0.7537, 0.5882, 0.768, 0.572, 0.7537, 0.5882, 0.778, 0.581,
                          0.7537, 0.5882, 0.755, 0.624, 0.7537, 0.5882, 0.772, 0.569, 0.7537, 0.5882, 0.757, 0.585,
                          0.7537, 0.5882, 0.738, 0.612, 0.7537, 0.5882, 0.77, 0.602, 0.7537, 0.5882, 0.74, 0.562,
                          0.7537, 0.5882, 0.754, 0.609, 0.7537, 0.5882, 0.749, 0.574, 0.7537, 0.5882, 0.775, 0.573,
                          0.7537, 0.5882, 0.76, 0.6, 0.7537, 0.5882, 0.735, 0.613, 0.7537, 0.5882, 0.749, 0.577,
                          0.7537, 0.5882, 0.765, 0.612, 0.7537, 0.5882, 0.758, 0.588, 0.7537, 0.5882, 0.763, 0.57,
                          0.7537, 0.7537, 0.787, 0.743, 0.7537, 0.7537, 0.762, 0.762, 0.7537, 0.7537, 0.769, 0.76,
                          0.7537, 0.7537, 0.746, 0.733, 0.7537, 0.7537, 0.755, 0.723, 0.7537, 0.7537, 0.75, 0.748,
                          0.7537, 0.7537, 0.747, 0.754, 0.7537, 0.7537, 0.774, 0.754, 0.7537, 0.7537, 0.734, 0.767,
                          0.7537, 0.7537, 0.747, 0.772, 0.7537, 0.7537, 0.76, 0.74, 0.7537, 0.7537, 0.722, 0.738,
                          0.7537, 0.7537, 0.732, 0.754, 0.7537, 0.7537, 0.782, 0.745, 0.7537, 0.7537, 0.744, 0.767,
                          0.7537, 0.7537, 0.756, 0.738, 0.7537, 0.7537, 0.718, 0.77, 0.7537, 0.7537, 0.771, 0.757]
        pvalue03_truth = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
                          0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
                          0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
                          0.001, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
                          0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
                          0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
                          0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
                          0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
                          0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
                          0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
                          0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
                          0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
        pvalue032_truth = [0.1326, 0.1326, 0.1206, 0.1302, 0.1561, 0.1561, 0.1552, 0.151, 0.1326, 0.1326, 0.119, 0.1292,
                           0.1561, 0.1561, 0.144, 0.1454, 0.1326, 0.1326, 0.131, 0.1238, 0.1561, 0.1561, 0.1488, 0.1484,
                           0.1196, 0.1196, 0.1072, 0.1066, 0.1217, 0.1217, 0.1142, 0.1102, 0.1196, 0.1196, 0.1084,
                           0.1032, 0.1217, 0.1217, 0.1078, 0.112, 0.1196, 0.1196, 0.1092, 0.1102, 0.1217, 0.1217,
                           0.112, 0.114, 0.2332, 0.2332, 0.2248, 0.2156, 0.2847, 0.2847, 0.2602, 0.265, 0.1033, 0.1033,
                           0.0994, 0.1026, 0.0946, 0.0946, 0.0892, 0.1, 0.2605, 0.2605, 0.2266, 0.2212, 0.2599, 0.2599,
                           0.2356, 0.2324, 0.8637, 0.1417, 0.8672, 0.1382, 0.8637, 0.1417, 0.8692, 0.1338, 0.8637,
                           0.1417, 0.8676, 0.1414, 0.8637, 0.1417, 0.8702, 0.1394, 0.8637, 0.1417, 0.8626, 0.1336,
                           0.8637, 0.1417, 0.8614, 0.1354, 0.8637, 0.1417, 0.8666, 0.127, 0.8637, 0.1417, 0.8568,
                           0.1314, 0.8637, 0.1417, 0.8632, 0.1334, 0.8637, 0.1417, 0.863, 0.1386, 0.8637, 0.1417,
                           0.8616, 0.1424, 0.8637, 0.1417, 0.8622, 0.1404, 0.8637, 0.1417, 0.8584, 0.13, 0.8637,
                           0.1417, 0.8584, 0.1382, 0.8637, 0.1417, 0.8748, 0.1234, 0.8637, 0.1417, 0.856, 0.1414,
                           0.8637, 0.1417, 0.8664, 0.1364, 0.8637, 0.1417, 0.8552, 0.1372, 0.8637, 0.8637, 0.858,
                           0.8652, 0.8637, 0.8637, 0.8558, 0.8666, 0.8637, 0.8637, 0.8584, 0.8644, 0.8637, 0.8637,
                           0.8614, 0.8678, 0.8637, 0.8637, 0.8696, 0.8682, 0.8637, 0.8637, 0.869, 0.8624, 0.8637,
                           0.8637, 0.8642, 0.8648, 0.8637, 0.8637, 0.8644, 0.8648, 0.8637, 0.8637, 0.8552, 0.8648,
                           0.8637, 0.8637, 0.8642, 0.86, 0.8637, 0.8637, 0.86, 0.8612, 0.8637, 0.8637, 0.8586, 0.8702,
                           0.8637, 0.8637, 0.8612, 0.8652, 0.8637, 0.8637, 0.8602, 0.8684, 0.8637, 0.8637, 0.8596,
                           0.859, 0.8637, 0.8637, 0.8622, 0.8512, 0.8637, 0.8637, 0.8594, 0.8672, 0.8637, 0.8637,
                           0.8626, 0.8716]
        self._assert_pvalues_match(data, pvalue01_truth, pvalue03_truth, pvalue032_truth)
